{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Slowly Varying Regression under Sparsity\n",
    "\n",
    "This notebook contains a demo run for some of the algorithms described in the paper Slowly Varying Regression under Sparsity.\n",
    "\n",
    "We implement all algorithms in Julia programming language (version 1.6) and using the JuMP.jl modeling language for mathematical optimization (version 0.21). We solve the optimization models using the Gurobi commercial solver (version 9.5)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "using JLD2\n",
    "\n",
    "include(\"src/svar_cutplane.jl\")\n",
    "include(\"src/svar_heuristic.jl\")\n",
    "include(\"src/regression_sum_of_norms.jl\")\n",
    "include(\"src/sparse_regression.jl\")\n",
    "include(\"src/cv.jl\")\n",
    "include(\"src/utils.jl\")\n",
    "include(\"src/svar_visualization.jl\")\n",
    ";"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###\n",
    "# Select dataset\n",
    "###\n",
    "\n",
    "datapath = \"data/energy_hour.jld2\"\n",
    "# datapath = \"data/housing.jld2\"\n",
    "X, Y, edges = load(datapath, \"X\"), load(datapath, \"Y\"), load(datapath, \"edges\")\n",
    "\n",
    "\n",
    "###\n",
    "# Train-valid-test split\n",
    "###\n",
    "\n",
    "N_total,T,D = size(X)\n",
    "train_size = convert(Int, floor(N_total*0.6))\n",
    "valid_size = convert(Int, floor(N_total*0.2))\n",
    "test_size = convert(Int, floor(N_total*0.2))\n",
    "all_indices = collect(1:N_total)\n",
    "train_indices = sample(all_indices, train_size, replace=false)\n",
    "all_indinces = setdiff(all_indices, train_indices)\n",
    "valid_indices = sample(all_indices, valid_size, replace=false)\n",
    "all_indinces = setdiff(all_indices, valid_indices)\n",
    "test_indices = sample(all_indices, test_size, replace=false)\n",
    "\n",
    "X_train, Y_train = view(X,train_indices,:,:), view(Y,train_indices,:)\n",
    "X_valid, Y_valid = view(X,valid_indices,:,:), view(Y,valid_indices,:)\n",
    "X_test, Y_test = view(X,test_indices,:,:), view(Y,test_indices,:)\n",
    "\n",
    "X_train, Y_train, X_valid, Y_valid, X_test, Y_test = standardize_data(X_train, Y_train, X_valid, Y_valid, X_test, Y_test)\n",
    ";"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Cross validation using svar_heuristic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###\n",
    "# Define search and constant hyperparameters\n",
    "###\n",
    "\n",
    "search_params = Dict(\n",
    "        :lambda_reg => 3,\n",
    "        :lambda_svar => 3,\n",
    "        # ENERGY\n",
    "        :sparsity => [10,15,20],\n",
    "        :global_sparsity_relative => [3,5],\n",
    "        :sparsely_varying => [20,30],\n",
    "#         # HOUSING\n",
    "#         :sparsity => [30,40,50,60],\n",
    "#         :global_sparsity_relative => [10,30],\n",
    "#         :sparsely_varying => [50,100],\n",
    "    )\n",
    "\n",
    "const_params = Dict(\n",
    "            :time_limit => 30.,\n",
    "            :verbose => 0,\n",
    "            :decrease_final_regularizer => true\n",
    "        )\n",
    "\n",
    "global DBG = stdout # print debug statement in stdout\n",
    "global DBG_list = [:progress] # only show progress debug statemets (e.g., hide memory related ones)\n",
    "\n",
    "###\n",
    "# Run cross validation\n",
    "###\n",
    "\n",
    "grid_search = holdout_grid_search(\n",
    "    X_train,\n",
    "    Y_train,\n",
    "    svar_heuristic,\n",
    "    search_params,\n",
    "    const_params,\n",
    "    X_valid=X_valid,\n",
    "    Y_valid=Y_valid,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###\n",
    "# Extract refit model and selected hyperparameter values\n",
    "###\n",
    "\n",
    "lnr_heuristic = grid_search.lnr_refit\n",
    "\n",
    "println(\"Heuristic sparsity params: $(eval_sparsity(lnr_heuristic.z, edges))\")\n",
    "println(\"Heuristic test r2: $(eval_r2(X_test,Y_test,lnr_heuristic.beta))\")\n",
    "\n",
    "sparsity = lnr_heuristic.lnr_params[:sparsity]\n",
    "global_sparsity = lnr_heuristic.lnr_params[:global_sparsity]\n",
    "sparsely_varying = lnr_heuristic.lnr_params[:sparsely_varying]\n",
    "lambda_svar = lnr_heuristic.lnr_params[:lambda_svar]\n",
    ";"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Run svar_cutplane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###\n",
    "# Run cutplane\n",
    "###\n",
    "\n",
    "lnr_cutplane = svar_cutplane(\n",
    "    vcat(X_train, X_valid), \n",
    "    vcat(Y_train, Y_valid),\n",
    "    edges=edges,\n",
    "    sparsity=sparsity,\n",
    "    global_sparsity=global_sparsity,\n",
    "    sparsely_varying=sparsely_varying,\n",
    "    lambda_svar=lambda_svar,\n",
    "    decrease_final_regularizer=true,\n",
    "    time_limit=300.,\n",
    "    verbose=1\n",
    ")\n",
    "\n",
    "println(\"Cutplane sparsity params: $(eval_sparsity(lnr_heuristic.z, edges))\")\n",
    "println(\"Cutplane test r2: $(eval_r2(X_test,Y_test,lnr_heuristic.beta))\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Visualize svar_cutplane "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coef_df = create_coef_df(\n",
    "    lnr_cutplane.beta,\n",
    ")\n",
    "\n",
    "g = plot_feature_variation(\n",
    "    coef_df,\n",
    "    edges,\n",
    "    fontsize=3,\n",
    "    method=:circular\n",
    ")\n",
    "\n",
    "display(g)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.6.5",
   "language": "julia",
   "name": "julia-1.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
